{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e278e6c5-bde9-4912-a6f0-5a580a9e70b1",
   "metadata": {},
   "source": [
    "# Example 02 - TorchSig Narrowband Classifier\n",
    "This notebook walks through a simple example of how to use the TorchSig Narrowband dataset and trainer. You can train a model from scratch or load a supported pre-trained model, then evaluate the trained network's performance. The experiment and its results are not meant to be significant in themselves; they serve only as a practical example of how the `torchsig` dataset and tools can be integrated into a typical [PyTorch](https://pytorch.org/) and/or [PyTorch Lightning](https://www.pytorchlightning.ai/) workflow.\n",
    "\n",
    "----"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43f0ec09-17a4-4cd8-b1e2-196b1400e14d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TorchSig imports\n",
    "from torchsig.transforms.target_transforms import DescToClassIndex\n",
    "from torchsig.transforms.transforms import (\n",
    "    RandomPhaseShift,\n",
    "    Normalize,\n",
    "    ComplexTo2D,\n",
    "    Compose,\n",
    ")\n",
    "from torchsig.utils.narrowband_trainer import NarrowbandTrainer\n",
    "from torchsig.datasets.torchsig_narrowband import TorchSigNarrowband\n",
    "from torchsig.datasets.datamodules import NarrowbandDataModule\n",
    "import numpy as np\n",
    "import cv2\n",
    "import os\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90b68819-5d6d-4ce2-9e41-22babdaaf754",
   "metadata": {},
   "source": [
    "----\n",
    "### Instantiate TorchSigNarrowband Dataset\n",
    "Here, we instantiate the TorchSigNarrowband training and validation datasets through a `NarrowbandDataModule`. We demonstrate how to compose multiple TorchSig transforms: first, a random phase shift impairment that uniformly samples a phase offset between -pi and +pi; next, a normalization of the complex tensor; and finally, a conversion of the complex data to a real-valued tensor with the real and imaginary parts as two channels. We additionally provide a target transform that maps the `SignalMetadata` objects, which are part of `SignalData` objects, to the format expected by the model we will train. In this case, we use the `DescToClassIndex` target transform to map class names to their indices within an ordered class list. A sample is drawn from the dataset and printed in the prediction section of this notebook to confirm functionality.\n",
    "\n",
    "For more details on the TorchSigNarrowband dataset instantiations, please see `00_example_narrowband_dataset.ipynb`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7775b0b-57a9-4870-b848-19cca6cc0084",
   "metadata": {},
   "outputs": [],
   "source": [
    "class_list = list(TorchSigNarrowband._idx_to_name_dict.values())\n",
    "num_classes = len(class_list)\n",
    "\n",
    "# Specify Transforms\n",
    "transform = Compose(\n",
    "    [\n",
    "        RandomPhaseShift(phase_offset=(-1, 1)),\n",
    "        Normalize(norm=np.inf),\n",
    "        ComplexTo2D(),\n",
    "    ]\n",
    ")\n",
    "target_transform = DescToClassIndex(class_list=class_list)\n",
    "\n",
    "datamodule = NarrowbandDataModule(\n",
    "    root='./datasets/narrowband_test_QA',\n",
    "    qa=True,\n",
    "    impaired=True,\n",
    "    transform=transform,\n",
    "    target_transform=target_transform,\n",
    "    batch_size=32,\n",
    "    num_workers=16,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "024c5f4a-7322-40cb-ba63-d17f672af032",
   "metadata": {},
   "source": [
    "---\n",
    "### Instantiate and Initialize the NarrowbandTrainer with specified parameters.\n",
    "\n",
    "    Args:\n",
    "        model_name (str): Name of the model to use.\n",
    "        num_epochs (int): Number of training epochs.\n",
    "        batch_size (int): Batch size for training.\n",
    "        num_workers (int): Number of workers for data loading.\n",
    "        learning_rate (float): Learning rate for the optimizer.\n",
    "        input_channels (int): Number of input channels into model.\n",
    "        data_path (str): Path to the dataset.\n",
    "        impaired (bool): Whether to use the impaired dataset.\n",
    "        qa (bool): Whether to use QA configuration.\n",
    "        checkpoint_path (str): Path to a checkpoint file to load the model weights.\n",
    "        datamodule (LightningDataModule): Custom data module instance.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54060a0e-4aa5-4a45-ba66-660191ceab36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the trainer with desired parameters\n",
    "trainer = NarrowbandTrainer(\n",
    "    model_name = 'xcit',\n",
    "    num_epochs = 2,\n",
    "    # batch_size = 32, # Uncomment if not passing in datamodule\n",
    "    # num_workers = 16, # Uncomment if not passing in datamodule\n",
    "    learning_rate = 1e-3,\n",
    "    input_channels = 2,\n",
    "    # data_path = '../datasets/narrowband_test_QA', # Uncomment if not passing in datamodule\n",
    "    # impaired = True, # Uncomment if not passing in datamodule\n",
    "    # qa = False, # Uncomment if not passing in datamodule\n",
    "    datamodule = datamodule,\n",
    "    checkpoint_path = None # If loading checkpoint, add path here\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7fbc155-71a2-4221-bbe8-e0e1a900945d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# View all available models\n",
    "print(trainer.available_models)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab2e7d8b-b972-46cb-b08d-bc6bbba34d76",
   "metadata": {},
   "source": [
    "---\n",
    "### Train or Fine Tune your model.\n",
    "    You can load any PyTorch Lightning checkpoint by providing its path above; otherwise, the specified model will be trained."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1107bedd-f74c-439d-bfd1-607982fba8a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the model\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5761819e-bb6e-44f5-9760-59758655fec0",
   "metadata": {},
   "source": [
    "---\n",
    "### Validate model\n",
    "    You can validate a model by loading its checkpoint in the initialization stage or after training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "645aa759-fb57-4806-ab1d-53d4a2085e12",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.validate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8663744b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train accuracy and loss plots\n",
    "acc_plot = cv2.imread(trainer.acc_plot_path)\n",
    "loss_plot = cv2.imread(trainer.loss_plot_path)\n",
    "\n",
    "plots = [acc_plot, loss_plot]\n",
    "\n",
    "fig = plt.figure(figsize=(21, 6))\n",
    "r = 1\n",
    "c = 3\n",
    "\n",
    "for i in range(2):\n",
    "    fig.add_subplot(r, c, i + 1)\n",
    "    plt.imshow(plots[i])\n",
    "    plt.axis('off') \n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45b6c815",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Confusion matrix (converted from OpenCV's BGR order to RGB for matplotlib)\n",
    "cm_plot = cv2.cvtColor(cv2.imread(trainer.cm_plot_path), cv2.COLOR_BGR2RGB)\n",
    "plt.imshow(cm_plot, aspect='auto')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9a6380b-ecfe-4cfb-9cf5-5528a5e3872e",
   "metadata": {},
   "source": [
    "---\n",
    "### Predict with model\n",
    "    You can make inferences/predictions with the model by loading a checkpoint in the initialization stage or after training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f907cf1-301b-4c4c-93d7-b1e079bf3210",
   "metadata": {},
   "source": [
    "#### Load Data\n",
    "    You can load whatever data you wish, as long as it is a torch.Tensor.\n",
    "    In this example, we load a sample from our validation set.\n",
    "\n",
    "    Data must have shape (batch_size, input_channels, data_length). You can use tensor.unsqueeze(dim=0) to add a batch dimension."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9baca30-3d19-4c84-b149-caa40ff8c867",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "datamodule.prepare_data()\n",
    "datamodule.setup(\"fit\")\n",
    "\n",
    "# Retrieve a validation sample and print out information to verify\n",
    "idx = np.random.randint(len(datamodule.val))\n",
    "data, label = datamodule.val[idx]\n",
    "data = torch.tensor(data).float().unsqueeze(dim=0)\n",
    "print(\"Dataset length: {}\".format(len(datamodule.val)))\n",
    "print(\"Data shape: {}\".format(data.shape))\n",
    "print(\"Label Index: {}\".format(label))\n",
    "print(\"Label Class: {}\".format(TorchSigNarrowband.convert_idx_to_name(label)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1acba2a5-a9d6-454c-8894-c6c669344be6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict on the sample loaded above (`data` is a torch.Tensor with a batch dimension)\n",
    "predictions = trainer.predict(data)[0]\n",
    "print(TorchSigNarrowband._idx_to_name_dict[predictions])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
